{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "Requirement already satisfied: datasets in /opt/homebrew/lib/python3.11/site-packages (2.14.7)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (1.26.2)\n",
      "Requirement already satisfied: pyarrow>=8.0.0 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (12.0.0)\n",
      "Requirement already satisfied: pyarrow-hotfix in /opt/homebrew/lib/python3.11/site-packages (from datasets) (0.6)\n",
      "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (0.3.6)\n",
      "Requirement already satisfied: pandas in /opt/homebrew/lib/python3.11/site-packages (from datasets) (2.1.4)\n",
      "Requirement already satisfied: requests>=2.19.0 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (2.31.0)\n",
      "Requirement already satisfied: tqdm>=4.62.1 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (4.66.1)\n",
      "Requirement already satisfied: xxhash in /opt/homebrew/lib/python3.11/site-packages (from datasets) (3.2.0)\n",
      "Requirement already satisfied: multiprocess in /opt/homebrew/lib/python3.11/site-packages (from datasets) (0.70.14)\n",
      "Requirement already satisfied: fsspec<=2023.10.0,>=2023.1.0 in /opt/homebrew/lib/python3.11/site-packages (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets) (2023.10.0)\n",
      "Requirement already satisfied: aiohttp in /opt/homebrew/lib/python3.11/site-packages (from datasets) (3.9.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (0.19.4)\n",
      "Requirement already satisfied: packaging in /opt/homebrew/lib/python3.11/site-packages (from datasets) (23.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/homebrew/lib/python3.11/site-packages (from datasets) (6.0.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->datasets) (23.1.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.4)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.3)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.0)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n",
      "Requirement already satisfied: filelock in /opt/homebrew/lib/python3.11/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/homebrew/lib/python3.11/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.8.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.19.0->datasets) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.19.0->datasets) (2.1.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/homebrew/lib/python3.11/site-packages (from pandas->datasets) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n",
      "Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%pip install datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's get the dataset and see what it looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task_id': 'HumanEval/0',\n",
       " 'prompt': 'from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n',\n",
       " 'canonical_solution': '    for idx, elem in enumerate(numbers):\\n        for idx2, elem2 in enumerate(numbers):\\n            if idx != idx2:\\n                distance = abs(elem - elem2)\\n                if distance < threshold:\\n                    return True\\n\\n    return False\\n',\n",
       " 'test': \"\\n\\nMETADATA = {\\n    'author': 'jt',\\n    'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\",\n",
       " 'entry_point': 'has_close_elements'}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import datasets\n",
    "ds = datasets.load_dataset(\"openai_humaneval\")\n",
    "ds['test'][0]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we try to solve the problem, let's just load a language model and make sure everything works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction(\n",
      "    answer='Paris'\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "import dspy, dotenv, os\n",
    "dotenv.load_dotenv(os.path.expanduser(\"~/.env\"))  # load OpenAI API key from .env file\n",
    "lm = dspy.OpenAI(model=\"gpt-3.5-turbo\", max_tokens=4000)\n",
    "dspy.settings.configure(lm=lm)\n",
    "\n",
    "predictor = dspy.Predict(\"question -> answer\")\n",
    "print(predictor(question=\"What is the capital of France?\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's write a program that actually outputs code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Now parsing: '{\"code\": \"from typing import List\\\\n\\\\n\\\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n    \\\\\"\\\\\"\\\\\" Check if in given list of numbers, are any two numbers closer to each other than\\\\n    given threshold.\\\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\\\n    False\\\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\\\n    True\\\\n    \\\\\"\\\\\"\\\\\"\\\\n    for i in range(len(numbers)):\\\\n        for j in range(i+1, len(numbers)):\\\\n            if abs(numbers[i] - numbers[j]) < threshold:\\\\n                return True\\\\n    return False\\\\n\"}'\n",
      "Parsed: PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\\n')\n",
      "is this the problem?\n",
      "kwargs={'prompt': PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n'), 'test': PythonCode(code=\"\\n\\nMETADATA = {\\n    'author': 'jt',\\n    'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"), 'entry_point': 'has_close_elements'}\n",
      "parsed={'solution': PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\\n')}\n",
      "after wards\n",
      "Prediction(\n",
      "    solution=PythonCode(code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\\n')\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from dspy import InputField, OutputField, Signature\n",
    "from dspy.functional import TypedPredictor\n",
    "import pydantic\n",
    "\n",
    "# We define a pydantic type that automatically checks if its argument is valid Python code.\n",
    "class PythonCode(pydantic.BaseModel):\n",
    "    code: str\n",
    "\n",
    "    @pydantic.field_validator('code')\n",
    "    def check_syntax(cls, v):\n",
    "        try:\n",
    "            # Attempt to compile the code snippet\n",
    "            compile(v, \"<string>\", \"exec\")\n",
    "        except SyntaxError as e:\n",
    "            # If a SyntaxError is raised, the code is not syntactically valid\n",
    "            raise ValueError(f\"Code is not syntactically valid: {e}\")\n",
    "            \n",
    "        return v\n",
    "\n",
    "# The signature is the main DSPy object. Note that we have types for the input and output fields,\n",
    "# which was not possible before.\n",
    "class CodeSignature(Signature):\n",
    "    prompt: PythonCode = InputField()\n",
    "    test: PythonCode = InputField()\n",
    "    entry_point: str = InputField()\n",
    "    solution: PythonCode = OutputField()\n",
    "\n",
    "predictor = TypedPredictor(CodeSignature)\n",
    "prediction = predictor(\n",
    "    prompt=PythonCode(code=ds['test'][0]['prompt']),\n",
    "    test=PythonCode(code=ds['test'][0]['test']),\n",
    "    entry_point=ds['test'][0]['entry_point']\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what's happening under the hood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "Given the fields `prompt`, `test`, `entry_point`, produce the fields `solution`.\n",
      "\n",
      "---\n",
      "\n",
      "Follow the following format.\n",
      "\n",
      "Prompt: ${prompt}\n",
      "\n",
      "Test: ${test}\n",
      "\n",
      "Entry Point: ${entry_point}\n",
      "\n",
      "Past Error (solution): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 2): An error to avoid in the future\n",
      "\n",
      "Solution:\n",
      "${solution}. Respond with a single JSON object. \n",
      "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n",
      "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n",
      "\n",
      "---\n",
      "\n",
      "Prompt: code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n'\n",
      "\n",
      "Test: {\"code\":\"\\n\\nMETADATA = {\\n    'author': 'jt',\\n    'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"}\n",
      "\n",
      "Entry Point: has_close_elements\n",
      "\n",
      "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n",
      "\n",
      "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (<string>, line 1): code (error type: value_error)\n",
      "\n",
      "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\"}\u001b[0m\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "Given the fields `prompt`, `test`, `entry_point`, produce the fields `solution`.\n",
      "\n",
      "---\n",
      "\n",
      "Follow the following format.\n",
      "\n",
      "Prompt: ${prompt}\n",
      "\n",
      "Test: ${test}\n",
      "\n",
      "Entry Point: ${entry_point}\n",
      "\n",
      "Past Error (solution): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 2): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 3): An error to avoid in the future\n",
      "\n",
      "Solution:\n",
      "${solution}. Respond with a single JSON object. \n",
      "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n",
      "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n",
      "\n",
      "---\n",
      "\n",
      "Prompt: code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n'\n",
      "\n",
      "Test: {\"code\":\"\\n\\nMETADATA = {\\n    'author': 'jt',\\n    'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"}\n",
      "\n",
      "Entry Point: has_close_elements\n",
      "\n",
      "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n",
      "\n",
      "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (<string>, line 1): code (error type: value_error)\n",
      "\n",
      "Past Error (solution, 3): Input should be a valid string: prompt (error type: string_type)\n",
      "\n",
      "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n    for i in range(len(numbers)):\\\\n        for j in range(i+1, len(numbers)):\\\\n            if abs(numbers[i] - numbers[j]) < threshold:\\\\n                return True\\\\n    return False\"}\u001b[0m\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "Given the fields `prompt`, `test`, `entry_point`, produce the fields `solution`.\n",
      "\n",
      "---\n",
      "\n",
      "Follow the following format.\n",
      "\n",
      "Prompt: ${prompt}\n",
      "\n",
      "Test: ${test}\n",
      "\n",
      "Entry Point: ${entry_point}\n",
      "\n",
      "Past Error (solution): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 2): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 3): An error to avoid in the future\n",
      "\n",
      "Past Error (solution, 4): An error to avoid in the future\n",
      "\n",
      "Solution:\n",
      "${solution}. Respond with a single JSON object. \n",
      "You MUST use this format: {\"code\": \"print('Hello, World!')\"}\n",
      "JSON Schema: {\"properties\": {\"code\": {\"title\": \"Code\", \"type\": \"string\"}}, \"required\": [\"code\"], \"title\": \"PythonCode\", \"type\": \"object\"}\n",
      "\n",
      "---\n",
      "\n",
      "Prompt: code='from typing import List\\n\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \"\"\"\\n'\n",
      "\n",
      "Test: {\"code\":\"\\n\\nMETADATA = {\\n    'author': 'jt',\\n    'dataset': 'test'\\n}\\n\\n\\ndef check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\\n\"}\n",
      "\n",
      "Entry Point: has_close_elements\n",
      "\n",
      "Past Error (solution): Input should be a valid string: prompt (error type: string_type)\n",
      "\n",
      "Past Error (solution, 2): Value error, Code is not syntactically valid: unexpected character after line continuation character (<string>, line 1): code (error type: value_error)\n",
      "\n",
      "Past Error (solution, 3): Input should be a valid string: prompt (error type: string_type)\n",
      "\n",
      "Past Error (solution, 4): Value error, Code is not syntactically valid: unexpected character after line continuation character (<string>, line 1): code (error type: value_error)\n",
      "\n",
      "Solution:\u001b[32m {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\\\n    for i in range(len(numbers)):\\\\n        for j in range(i+1, len(numbers)):\\\\n            if abs(numbers[i] - numbers[j]) < threshold:\\\\n                return True\\\\n    return False\"}\u001b[0m\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lm.inspect_history(n=3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def has_close_elements(numbers: List[float], threshold: float) -> bool:\n",
      "    for i in range(len(numbers)):\n",
      "        for j in range(i+1, len(numbers)):\n",
      "            if abs(numbers[i] - numbers[j]) < threshold:\n",
      "                return True\n",
      "    return False\n"
     ]
    }
   ],
   "source": [
    "d = {\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\"}\n",
    "print(d[\"code\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "JSONDecodeError",
     "evalue": "Invalid control character at: line 1 column 82 (char 81)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mJSONDecodeError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m{\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m: \u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdef has_close_elements(numbers: List[float], threshold: float) -> bool:\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m    for i in range(len(numbers)):\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m        for j in range(i+1, len(numbers)):\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m            if abs(numbers[i] - numbers[j]) < threshold:\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m                return True\u001b[39;49m\u001b[38;5;130;43;01m\\n\u001b[39;49;00m\u001b[38;5;124;43m    return False\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m}\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m    341\u001b[0m     s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m    344\u001b[0m         parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m    345\u001b[0m         parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    348\u001b[0m     \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n",
      "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m    332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m    333\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m    334\u001b[0m \u001b[38;5;124;03m    containing a JSON document).\u001b[39;00m\n\u001b[1;32m    335\u001b[0m \n\u001b[1;32m    336\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m     obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    338\u001b[0m     end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m    339\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n",
      "File \u001b[0;32m/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m    344\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Decode a JSON document from ``s`` (a ``str`` beginning with\u001b[39;00m\n\u001b[1;32m    345\u001b[0m \u001b[38;5;124;03ma JSON document) and return a 2-tuple of the Python\u001b[39;00m\n\u001b[1;32m    346\u001b[0m \u001b[38;5;124;03mrepresentation and the index in ``s`` where the document ended.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    350\u001b[0m \n\u001b[1;32m    351\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m     obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m    355\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "\u001b[0;31mJSONDecodeError\u001b[0m: Invalid control character at: line 1 column 82 (char 81)"
     ]
    }
   ],
   "source": [
    "import json\n",
    "json.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\"}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'code': 'def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import ujson\n",
    "ujson.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\"} ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'code': 'def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json.loads(ujson.dumps(ujson.loads('{\"code\": \"def has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    for i in range(len(numbers)):\\n        for j in range(i+1, len(numbers)):\\n            if abs(numbers[i] - numbers[j]) < threshold:\\n                return True\\n    return False\"} ')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that `functional` first created an example value, `{\"code\": \"print('Hello, World!')\"}`, which can be useful to bootstrap the JSON generation.\n",
    "After that, the model still failed to generate valid JSON.\n",
    "It apparently decided to first repeat the schema, and then give the actual code \"as an example\".\n",
    "The validator caught the error and fed it back as a \"Past Error\", which led the model to finally produce a valid output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need a way to run code. Doing this safely is surprisingly tricky in Python (see https://stackoverflow.com/questions/3068139/how-can-i-sandbox-python-in-pure-python), so we'll just YOLO and call \"exec\" with globals={}. Note that this gives no real isolation from the code being run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "None\n",
      "AssertionError()\n"
     ]
    }
   ],
   "source": [
    "from repl import execute_code\n",
    "\n",
    "# execute_code returns None when the snippet runs cleanly, otherwise the exception it raised.\n",
    "for snippet in (\"print(3)\", \"assert False\"):\n",
    "    print(execute_code(snippet))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run the metric on all the \"canonical solutions\" from HumanEval to check that the test harness works: every one of them should pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Score with the original model:\n",
      "100.0\n"
     ]
    }
   ],
   "source": [
    "from dspy import Example\n",
    "\n",
    "devset = [Example(\n",
    "    prompt=PythonCode(code=test['prompt']),\n",
    "    test=PythonCode(code=test['test']),\n",
    "    entry_point=test['entry_point'],\n",
    "    solution=PythonCode(code=test['prompt']+test['canonical_solution']),\n",
    ").with_inputs('prompt', 'test', 'entry_point') for test in ds['test']]\n",
    "\n",
    "trainset = devset[:40]\n",
    "testset = devset[40:]\n",
    "\n",
    "def build_test_program(solution_code, example):\n",
    "    \"\"\"Join a candidate solution with the example's HumanEval `check` harness into one runnable script.\"\"\"\n",
    "    return (\n",
    "        \"from typing import List\\n\"\n",
    "        + f\"{solution_code}\\n\"\n",
    "        + f\"{example.test.code}\\n\"\n",
    "        + f\"check({example.entry_point})\"\n",
    "    )\n",
    "\n",
    "def test_code(timeout=5):\n",
    "    \"\"\"Build a dspy metric: 1 if `pred.solution` passes the example's tests within `timeout`, else 0.\"\"\"\n",
    "    def metric(example, pred, trace=None):\n",
    "        if pred.solution.code is None:\n",
    "            return 0\n",
    "        error = execute_code(build_test_program(pred.solution.code, example), timeout=timeout)\n",
    "        return int(error is None)\n",
    "    return metric\n",
    "\n",
    "metric5s = test_code(timeout=5)\n",
    "\n",
    "print(\"Score with the original model:\")\n",
    "print(100 * sum(metric5s(example, example) for example in testset) / len(testset))\n",
    "\n",
    "# Sanity check: show the first canonical solution that fails its own tests (none expected).\n",
    "for example in devset:\n",
    "    if not metric5s(example, example):\n",
    "        print(\"Bad example:\")\n",
    "        code = build_test_program(example.solution.code, example)\n",
    "        print(code)\n",
    "        # Same timeout as metric5s, so the re-run matches the conditions the metric used.\n",
    "        error = execute_code(code, timeout=5)\n",
    "        print(f\"{error=}\")\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now evaluate our (not yet optimized) program on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/124 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 0.0 / 3  (0.0):   2%|▏         | 2/124 [00:00<00:01, 77.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 24.0 / 37  (64.9):  29%|██▉       | 36/124 [00:00<00:01, 81.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 25.0 / 40  (62.5):  31%|███▏      | 39/124 [00:00<00:01, 81.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 45.0 / 80  (56.2):  64%|██████▎   | 79/124 [00:00<00:00, 109.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 45.0 / 81  (55.6):  65%|██████▍   | 80/124 [00:00<00:00, 109.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 64.0 / 119  (53.8):  95%|█████████▌| 118/124 [00:01<00:00, 110.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37\n",
      "15\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 66.0 / 124  (53.2): 100%|██████████| 124/124 [00:01<00:00, 116.06it/s]\n",
      "/Users/ahle/repos/dspy/dspy/evaluate/evaluate.py:143: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  df = df.applymap(truncate_cell)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 66.0 / 124  (53.2%)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_ac8ed th {\n",
       "  text-align: left;\n",
       "}\n",
       "#T_ac8ed td {\n",
       "  text-align: left;\n",
       "}\n",
       "#T_ac8ed_row0_col0, #T_ac8ed_row0_col1, #T_ac8ed_row0_col2, #T_ac8ed_row0_col3, #T_ac8ed_row0_col4, #T_ac8ed_row0_col5, #T_ac8ed_row0_col6, #T_ac8ed_row1_col0, #T_ac8ed_row1_col1, #T_ac8ed_row1_col2, #T_ac8ed_row1_col3, #T_ac8ed_row1_col4, #T_ac8ed_row1_col5, #T_ac8ed_row1_col6, #T_ac8ed_row2_col0, #T_ac8ed_row2_col1, #T_ac8ed_row2_col2, #T_ac8ed_row2_col3, #T_ac8ed_row2_col4, #T_ac8ed_row2_col5, #T_ac8ed_row2_col6, #T_ac8ed_row3_col0, #T_ac8ed_row3_col1, #T_ac8ed_row3_col2, #T_ac8ed_row3_col3, #T_ac8ed_row3_col4, #T_ac8ed_row3_col5, #T_ac8ed_row3_col6, #T_ac8ed_row4_col0, #T_ac8ed_row4_col1, #T_ac8ed_row4_col2, #T_ac8ed_row4_col3, #T_ac8ed_row4_col4, #T_ac8ed_row4_col5, #T_ac8ed_row4_col6 {\n",
       "  text-align: left;\n",
       "  white-space: pre-wrap;\n",
       "  word-wrap: break-word;\n",
       "  max-width: 400px;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_ac8ed\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_ac8ed_level0_col0\" class=\"col_heading level0 col0\" >prompt</th>\n",
       "      <th id=\"T_ac8ed_level0_col1\" class=\"col_heading level0 col1\" >test</th>\n",
       "      <th id=\"T_ac8ed_level0_col2\" class=\"col_heading level0 col2\" >entry_point</th>\n",
       "      <th id=\"T_ac8ed_level0_col3\" class=\"col_heading level0 col3\" >example_solution</th>\n",
       "      <th id=\"T_ac8ed_level0_col4\" class=\"col_heading level0 col4\" >pred_solution</th>\n",
       "      <th id=\"T_ac8ed_level0_col5\" class=\"col_heading level0 col5\" >metric</th>\n",
       "      <th id=\"T_ac8ed_level0_col6\" class=\"col_heading level0 col6\" >solution</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_ac8ed_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_ac8ed_row0_col0\" class=\"data row0 col0\" >code='\\n\\ndef triples_sum_to_zero(l: list):\\n \"\"\"\\n triples_sum_to_zero takes a list of integers as an input.\\n it returns True if there are three distinct elements in the list...</td>\n",
       "      <td id=\"T_ac8ed_row0_col1\" class=\"data row0 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([1, 3, 5, 0]) == False\\n assert candidate([1, 3, 5, -1]) == False\\n assert candidate([1, 3, -2, 1]) == True\\n...</td>\n",
       "      <td id=\"T_ac8ed_row0_col2\" class=\"data row0 col2\" >triples_sum_to_zero</td>\n",
       "      <td id=\"T_ac8ed_row0_col3\" class=\"data row0 col3\" >code='\\n\\ndef triples_sum_to_zero(l: list):\\n \"\"\"\\n triples_sum_to_zero takes a list of integers as an input.\\n it returns True if there are three distinct elements in the list...</td>\n",
       "      <td id=\"T_ac8ed_row0_col4\" class=\"data row0 col4\" >code='def triples_sum_to_zero(l: list):\\n for i in range(len(l)):\\n for j in range(i+1, len(l)):\\n for k in range(j+1, len(l)):\\n if l[i] + l[j] + l[k] == 0:\\n...</td>\n",
       "      <td id=\"T_ac8ed_row0_col5\" class=\"data row0 col5\" >1.0</td>\n",
       "      <td id=\"T_ac8ed_row0_col6\" class=\"data row0 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_ac8ed_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_ac8ed_row1_col0\" class=\"data row1 col0\" >code='\\n\\ndef car_race_collision(n: int):\\n \"\"\"\\n Imagine a road that\\'s a perfectly straight infinitely long line.\\n n cars are driving left to right; simultaneously, a different set...</td>\n",
       "      <td id=\"T_ac8ed_row1_col1\" class=\"data row1 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n    assert candidate(2) == 4\\n    assert candidate(3) == 9\\n    assert candidate(4) == 16\\n    assert candidate(8) == 64\\n    assert candidate(10) == 100\\n\\n'</td>\n",
       "      <td id=\"T_ac8ed_row1_col2\" class=\"data row1 col2\" >car_race_collision</td>\n",
       "      <td id=\"T_ac8ed_row1_col3\" class=\"data row1 col3\" >code='\\n\\ndef car_race_collision(n: int):\\n \"\"\"\\n Imagine a road that\\'s a perfectly straight infinitely long line.\\n n cars are driving left to right; simultaneously, a different set...</td>\n",
       "      <td id=\"T_ac8ed_row1_col4\" class=\"data row1 col4\" >code='def car_race_collision(n: int):\\n    return n ** 2'</td>\n",
       "      <td id=\"T_ac8ed_row1_col5\" class=\"data row1 col5\" >1.0</td>\n",
       "      <td id=\"T_ac8ed_row1_col6\" class=\"data row1 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_ac8ed_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_ac8ed_row2_col0\" class=\"data row2 col0\" >code='\\n\\ndef incr_list(l: list):\\n \"\"\"Return list with elements incremented by 1.\\n >>> incr_list([1, 2, 3])\\n [2, 3, 4]\\n >>> incr_list([5, 3, 5, 2, 3, 3, 9,...</td>\n",
       "      <td id=\"T_ac8ed_row2_col1\" class=\"data row2 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([]) == []\\n assert candidate([3, 2, 1]) == [4, 3, 2]\\n assert candidate([5, 2, 5, 2, 3, 3, 9, 0,...</td>\n",
       "      <td id=\"T_ac8ed_row2_col2\" class=\"data row2 col2\" >incr_list</td>\n",
       "      <td id=\"T_ac8ed_row2_col3\" class=\"data row2 col3\" >code='\\n\\ndef incr_list(l: list):\\n \"\"\"Return list with elements incremented by 1.\\n >>> incr_list([1, 2, 3])\\n [2, 3, 4]\\n >>> incr_list([5, 3, 5, 2, 3, 3, 9,...</td>\n",
       "      <td id=\"T_ac8ed_row2_col4\" class=\"data row2 col4\" >code='def incr_list(l: list):\\n    return [x + 1 for x in l]'</td>\n",
       "      <td id=\"T_ac8ed_row2_col5\" class=\"data row2 col5\" >1.0</td>\n",
       "      <td id=\"T_ac8ed_row2_col6\" class=\"data row2 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_ac8ed_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_ac8ed_row3_col0\" class=\"data row3 col0\" >code='\\n\\ndef pairs_sum_to_zero(l):\\n \"\"\"\\n pairs_sum_to_zero takes a list of integers as an input.\\n it returns True if there are two distinct elements in the list that\\n...</td>\n",
       "      <td id=\"T_ac8ed_row3_col1\" class=\"data row3 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([1, 3, 5, 0]) == False\\n assert candidate([1, 3, -2, 1]) == False\\n assert candidate([1, 2, 3, 7]) == False\\n...</td>\n",
       "      <td id=\"T_ac8ed_row3_col2\" class=\"data row3 col2\" >pairs_sum_to_zero</td>\n",
       "      <td id=\"T_ac8ed_row3_col3\" class=\"data row3 col3\" >code='\\n\\ndef pairs_sum_to_zero(l):\\n \"\"\"\\n pairs_sum_to_zero takes a list of integers as an input.\\n it returns True if there are two distinct elements in the list that\\n...</td>\n",
       "      <td id=\"T_ac8ed_row3_col4\" class=\"data row3 col4\" >code='def pairs_sum_to_zero(l):\\n    return any(-x in l for x in l if x != 0)\\n'</td>\n",
       "      <td id=\"T_ac8ed_row3_col5\" class=\"data row3 col5\" >1.0</td>\n",
       "      <td id=\"T_ac8ed_row3_col6\" class=\"data row3 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_ac8ed_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_ac8ed_row4_col0\" class=\"data row4 col0\" >code='\\n\\ndef change_base(x: int, base: int):\\n \"\"\"Change numerical base of input number x to base.\\n return string representation after the conversion.\\n base numbers are less than...</td>\n",
       "      <td id=\"T_ac8ed_row4_col1\" class=\"data row4 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate(8, 3) == \"22\"\\n assert candidate(9, 3) == \"100\"\\n assert candidate(234, 2) == \"11101010\"\\n assert candidate(16, 2) == \"10000\"\\n assert...</td>\n",
       "      <td id=\"T_ac8ed_row4_col2\" class=\"data row4 col2\" >change_base</td>\n",
       "      <td id=\"T_ac8ed_row4_col3\" class=\"data row4 col3\" >code='\\n\\ndef change_base(x: int, base: int):\\n \"\"\"Change numerical base of input number x to base.\\n return string representation after the conversion.\\n base numbers are less than...</td>\n",
       "      <td id=\"T_ac8ed_row4_col4\" class=\"data row4 col4\" >code='def change_base(x: int, base: int):\\n    return str(int(str(x), base))'</td>\n",
       "      <td id=\"T_ac8ed_row4_col5\" class=\"data row4 col5\" >0.0</td>\n",
       "      <td id=\"T_ac8ed_row4_col6\" class=\"data row4 col6\" >nan</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x2b70c79d0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <div style='\n",
       "                    text-align: center; \n",
       "                    font-size: 16px; \n",
       "                    font-weight: bold; \n",
       "                    color: #555; \n",
       "                    margin: 10px 0;'>\n",
       "                    ... 119 more rows not displayed ...\n",
       "                </div>\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "53.23\n"
     ]
    }
   ],
   "source": [
    "from dspy.evaluate.evaluate import Evaluate\n",
    "\n",
    "# Score the unoptimized predictor on the held-out split, showing the first 5 result rows.\n",
    "eval_settings = dict(\n",
    "    devset=testset,\n",
    "    num_threads=30,\n",
    "    display_progress=True,\n",
    "    display_table=5,\n",
    "    max_errors=100,\n",
    ")\n",
    "evaluator = Evaluate(**eval_settings)\n",
    "res = evaluator(predictor, metric5s)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to optimize the program a bit by bootstrapping few-shot demonstrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiling...\n",
      "Going to sample between 1 and 4 traces per predictor.\n",
      "Will attempt to train 5 candidate sets.\n",
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 25.0 / 40  (62.5): 100%|██████████| 40/40 [00:01<00:00, 22.16it/s] \n",
      "/Users/ahle/repos/dspy/dspy/evaluate/evaluate.py:143: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  df = df.applymap(truncate_cell)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 25.0 / 40  (62.5%)\n",
      "Score: 62.5 for set: [0]\n",
      "New best score: 62.5 for seed -3\n",
      "Scores so far: [62.5]\n",
      "Best score: 62.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 34.0 / 40  (85.0): 100%|██████████| 40/40 [00:18<00:00,  2.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n",
      "Average Metric: 34.0 / 40  (85.0%)\n",
      "Score: 85.0 for set: [8]\n",
      "New best score: 85.0 for seed -2\n",
      "Scores so far: [62.5, 85.0]\n",
      "Best score: 85.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 4/40 [00:02<00:19,  1.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 4 full traces after 5 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 33 / 40  (82.5): 100%|██████████| 40/40 [01:48<00:00,  2.72s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 33 / 40  (82.5%)\n",
      "Score: 82.5 for set: [8]\n",
      "Scores so far: [62.5, 85.0, 82.5]\n",
      "Best score: 85.0\n",
      "Average of max per entry across top 1 scores: 0.85\n",
      "Average of max per entry across top 2 scores: 0.975\n",
      "Average of max per entry across top 3 scores: 1.0\n",
      "Average of max per entry across top 5 scores: 1.0\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 4/40 [00:10<01:37,  2.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 4 full traces after 5 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 35 / 40  (87.5): 100%|██████████| 40/40 [00:12<00:00,  3.15it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 35 / 40  (87.5%)\n",
      "Score: 87.5 for set: [8]\n",
      "New best score: 87.5 for seed 0\n",
      "Scores so far: [62.5, 85.0, 82.5, 87.5]\n",
      "Best score: 87.5\n",
      "Average of max per entry across top 1 scores: 0.875\n",
      "Average of max per entry across top 2 scores: 0.975\n",
      "Average of max per entry across top 3 scores: 1.0\n",
      "Average of max per entry across top 5 scores: 1.0\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 2/40 [00:03<01:07,  1.79s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 2 full traces after 3 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 36 / 40  (90.0): 100%|██████████| 40/40 [00:09<00:00,  4.35it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 36 / 40  (90.0%)\n",
      "Score: 90.0 for set: [8]\n",
      "New best score: 90.0 for seed 1\n",
      "Scores so far: [62.5, 85.0, 82.5, 87.5, 90.0]\n",
      "Best score: 90.0\n",
      "Average of max per entry across top 1 scores: 0.9\n",
      "Average of max per entry across top 2 scores: 0.95\n",
      "Average of max per entry across top 3 scores: 0.975\n",
      "Average of max per entry across top 5 scores: 1.0\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 2/40 [00:05<01:37,  2.57s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 1 full traces after 3 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 34 / 40  (85.0): 100%|██████████| 40/40 [00:12<00:00,  3.31it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 34 / 40  (85.0%)\n",
      "Score: 85.0 for set: [8]\n",
      "Scores so far: [62.5, 85.0, 82.5, 87.5, 90.0, 85.0]\n",
      "Best score: 90.0\n",
      "Average of max per entry across top 1 scores: 0.9\n",
      "Average of max per entry across top 2 scores: 0.95\n",
      "Average of max per entry across top 3 scores: 0.975\n",
      "Average of max per entry across top 5 scores: 1.0\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 3/40 [00:07<01:34,  2.56s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 2 full traces after 4 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 32 / 40  (80.0): 100%|██████████| 40/40 [00:18<00:00,  2.16it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 32 / 40  (80.0%)\n",
      "Score: 80.0 for set: [8]\n",
      "Scores so far: [62.5, 85.0, 82.5, 87.5, 90.0, 85.0, 80.0]\n",
      "Best score: 90.0\n",
      "Average of max per entry across top 1 scores: 0.9\n",
      "Average of max per entry across top 2 scores: 0.95\n",
      "Average of max per entry across top 3 scores: 0.975\n",
      "Average of max per entry across top 5 scores: 1.0\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 2/40 [00:03<01:07,  1.76s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrapped 2 full traces after 3 examples in round 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 35 / 40  (87.5): 100%|██████████| 40/40 [00:12<00:00,  3.18it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 35 / 40  (87.5%)\n",
      "Score: 87.5 for set: [8]\n",
      "Scores so far: [62.5, 85.0, 82.5, 87.5, 90.0, 85.0, 80.0, 87.5]\n",
      "Best score: 90.0\n",
      "Average of max per entry across top 1 scores: 0.9\n",
      "Average of max per entry across top 2 scores: 0.95\n",
      "Average of max per entry across top 3 scores: 0.975\n",
      "Average of max per entry across top 5 scores: 0.975\n",
      "Average of max per entry across top 8 scores: 1.0\n",
      "Average of max per entry across top 9999 scores: 1.0\n",
      "8 candidate programs found.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from dspy.teleprompt.bootstrap import BootstrapFewShot\n",
    "from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch\n",
    "\n",
    "# Random search over bootstrapped few-shot demo sets, each candidate scored with metric5s on trainset.\n",
    "# BootstrapFewShot is kept imported as an alternative teleprompter to try instead.\n",
    "print(\"Compiling...\")\n",
    "teleprompter = BootstrapFewShotWithRandomSearch(\n",
    "    metric=metric5s,\n",
    "    num_threads=30,\n",
    "    num_candidate_programs=5,\n",
    "    max_labeled_demos=8,\n",
    ")\n",
    "compiled = teleprompter.compile(predictor, trainset=trainset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, evaluate the optimized program on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 59.0 / 95  (62.1):  77%|███████▋  | 95/124 [00:16<00:04,  7.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 74.0 / 120  (61.7):  97%|█████████▋| 120/124 [00:22<00:01,  2.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 75.0 / 123  (61.0):  99%|█████████▉| 123/124 [00:25<00:00,  1.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 75.0 / 124  (60.5): 100%|██████████| 124/124 [00:30<00:00,  4.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for example in dev set: \t\t Too many retries\n",
      "Average Metric: 75.0 / 124  (60.5%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_1f859 th {\n",
       "  text-align: left;\n",
       "}\n",
       "#T_1f859 td {\n",
       "  text-align: left;\n",
       "}\n",
       "#T_1f859_row0_col0, #T_1f859_row0_col1, #T_1f859_row0_col2, #T_1f859_row0_col3, #T_1f859_row0_col4, #T_1f859_row0_col5, #T_1f859_row0_col6, #T_1f859_row1_col0, #T_1f859_row1_col1, #T_1f859_row1_col2, #T_1f859_row1_col3, #T_1f859_row1_col4, #T_1f859_row1_col5, #T_1f859_row1_col6, #T_1f859_row2_col0, #T_1f859_row2_col1, #T_1f859_row2_col2, #T_1f859_row2_col3, #T_1f859_row2_col4, #T_1f859_row2_col5, #T_1f859_row2_col6, #T_1f859_row3_col0, #T_1f859_row3_col1, #T_1f859_row3_col2, #T_1f859_row3_col3, #T_1f859_row3_col4, #T_1f859_row3_col5, #T_1f859_row3_col6, #T_1f859_row4_col0, #T_1f859_row4_col1, #T_1f859_row4_col2, #T_1f859_row4_col3, #T_1f859_row4_col4, #T_1f859_row4_col5, #T_1f859_row4_col6 {\n",
       "  text-align: left;\n",
       "  white-space: pre-wrap;\n",
       "  word-wrap: break-word;\n",
       "  max-width: 400px;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_1f859\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_1f859_level0_col0\" class=\"col_heading level0 col0\" >prompt</th>\n",
       "      <th id=\"T_1f859_level0_col1\" class=\"col_heading level0 col1\" >test</th>\n",
       "      <th id=\"T_1f859_level0_col2\" class=\"col_heading level0 col2\" >entry_point</th>\n",
       "      <th id=\"T_1f859_level0_col3\" class=\"col_heading level0 col3\" >example_solution</th>\n",
       "      <th id=\"T_1f859_level0_col4\" class=\"col_heading level0 col4\" >pred_solution</th>\n",
       "      <th id=\"T_1f859_level0_col5\" class=\"col_heading level0 col5\" >metric</th>\n",
       "      <th id=\"T_1f859_level0_col6\" class=\"col_heading level0 col6\" >solution</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_1f859_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_1f859_row0_col0\" class=\"data row0 col0\" >code='\\n\\ndef triples_sum_to_zero(l: list):\\n \"\"\"\\n triples_sum_to_zero takes a list of integers as an input.\\n it returns True if there are three distinct elements in the list...</td>\n",
       "      <td id=\"T_1f859_row0_col1\" class=\"data row0 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([1, 3, 5, 0]) == False\\n assert candidate([1, 3, 5, -1]) == False\\n assert candidate([1, 3, -2, 1]) == True\\n...</td>\n",
       "      <td id=\"T_1f859_row0_col2\" class=\"data row0 col2\" >triples_sum_to_zero</td>\n",
       "      <td id=\"T_1f859_row0_col3\" class=\"data row0 col3\" >code='\\n\\ndef triples_sum_to_zero(l: list):\\n \"\"\"\\n triples_sum_to_zero takes a list of integers as an input.\\n it returns True if there are three distinct elements in the list...</td>\n",
       "      <td id=\"T_1f859_row0_col4\" class=\"data row0 col4\" >code='\\n\\ndef triples_sum_to_zero(l: list):\\n \"\"\"\\n triples_sum_to_zero takes a list of integers as an input.\\n it returns True if there are three distinct elements in the list...</td>\n",
       "      <td id=\"T_1f859_row0_col5\" class=\"data row0 col5\" >1.0</td>\n",
       "      <td id=\"T_1f859_row0_col6\" class=\"data row0 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_1f859_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_1f859_row1_col0\" class=\"data row1 col0\" >code='\\n\\ndef car_race_collision(n: int):\\n \"\"\"\\n Imagine a road that\\'s a perfectly straight infinitely long line.\\n n cars are driving left to right; simultaneously, a different set...</td>\n",
       "      <td id=\"T_1f859_row1_col1\" class=\"data row1 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n    assert candidate(2) == 4\\n    assert candidate(3) == 9\\n    assert candidate(4) == 16\\n    assert candidate(8) == 64\\n    assert candidate(10) == 100\\n\\n'</td>\n",
       "      <td id=\"T_1f859_row1_col2\" class=\"data row1 col2\" >car_race_collision</td>\n",
       "      <td id=\"T_1f859_row1_col3\" class=\"data row1 col3\" >code='\\n\\ndef car_race_collision(n: int):\\n \"\"\"\\n Imagine a road that\\'s a perfectly straight infinitely long line.\\n n cars are driving left to right; simultaneously, a different set...</td>\n",
       "      <td id=\"T_1f859_row1_col4\" class=\"data row1 col4\" >code='\\n\\ndef car_race_collision(n: int):\\n \"\"\"\\n Imagine a road that\\'s a perfectly straight infinitely long line.\\n n cars are driving left to right; simultaneously, a different set...</td>\n",
       "      <td id=\"T_1f859_row1_col5\" class=\"data row1 col5\" >1.0</td>\n",
       "      <td id=\"T_1f859_row1_col6\" class=\"data row1 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_1f859_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_1f859_row2_col0\" class=\"data row2 col0\" >code='\\n\\ndef incr_list(l: list):\\n \"\"\"Return list with elements incremented by 1.\\n >>> incr_list([1, 2, 3])\\n [2, 3, 4]\\n >>> incr_list([5, 3, 5, 2, 3, 3, 9,...</td>\n",
       "      <td id=\"T_1f859_row2_col1\" class=\"data row2 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([]) == []\\n assert candidate([3, 2, 1]) == [4, 3, 2]\\n assert candidate([5, 2, 5, 2, 3, 3, 9, 0,...</td>\n",
       "      <td id=\"T_1f859_row2_col2\" class=\"data row2 col2\" >incr_list</td>\n",
       "      <td id=\"T_1f859_row2_col3\" class=\"data row2 col3\" >code='\\n\\ndef incr_list(l: list):\\n \"\"\"Return list with elements incremented by 1.\\n >>> incr_list([1, 2, 3])\\n [2, 3, 4]\\n >>> incr_list([5, 3, 5, 2, 3, 3, 9,...</td>\n",
       "      <td id=\"T_1f859_row2_col4\" class=\"data row2 col4\" >code='\\n\\ndef incr_list(l: list):\\n \"\"\"Return list with elements incremented by 1.\\n >>> incr_list([1, 2, 3])\\n [2, 3, 4]\\n >>> incr_list([5, 3, 5, 2, 3, 3, 9,...</td>\n",
       "      <td id=\"T_1f859_row2_col5\" class=\"data row2 col5\" >1.0</td>\n",
       "      <td id=\"T_1f859_row2_col6\" class=\"data row2 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_1f859_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_1f859_row3_col0\" class=\"data row3 col0\" >code='\\n\\ndef pairs_sum_to_zero(l):\\n \"\"\"\\n pairs_sum_to_zero takes a list of integers as an input.\\n it returns True if there are two distinct elements in the list that\\n...</td>\n",
       "      <td id=\"T_1f859_row3_col1\" class=\"data row3 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate([1, 3, 5, 0]) == False\\n assert candidate([1, 3, -2, 1]) == False\\n assert candidate([1, 2, 3, 7]) == False\\n...</td>\n",
       "      <td id=\"T_1f859_row3_col2\" class=\"data row3 col2\" >pairs_sum_to_zero</td>\n",
       "      <td id=\"T_1f859_row3_col3\" class=\"data row3 col3\" >code='\\n\\ndef pairs_sum_to_zero(l):\\n \"\"\"\\n pairs_sum_to_zero takes a list of integers as an input.\\n it returns True if there are two distinct elements in the list that\\n...</td>\n",
       "      <td id=\"T_1f859_row3_col4\" class=\"data row3 col4\" >code='\\n\\ndef pairs_sum_to_zero(l):\\n \"\"\"\\n pairs_sum_to_zero takes a list of integers as an input.\\n it returns True if there are two distinct elements in the list that\\n...</td>\n",
       "      <td id=\"T_1f859_row3_col5\" class=\"data row3 col5\" >1.0</td>\n",
       "      <td id=\"T_1f859_row3_col6\" class=\"data row3 col6\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_1f859_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_1f859_row4_col0\" class=\"data row4 col0\" >code='\\n\\ndef change_base(x: int, base: int):\\n \"\"\"Change numerical base of input number x to base.\\n return string representation after the conversion.\\n base numbers are less than...</td>\n",
       "      <td id=\"T_1f859_row4_col1\" class=\"data row4 col1\" >code='\\n\\nMETADATA = {}\\n\\n\\ndef check(candidate):\\n assert candidate(8, 3) == \"22\"\\n assert candidate(9, 3) == \"100\"\\n assert candidate(234, 2) == \"11101010\"\\n assert candidate(16, 2) == \"10000\"\\n assert...</td>\n",
       "      <td id=\"T_1f859_row4_col2\" class=\"data row4 col2\" >change_base</td>\n",
       "      <td id=\"T_1f859_row4_col3\" class=\"data row4 col3\" >code='\\n\\ndef change_base(x: int, base: int):\\n \"\"\"Change numerical base of input number x to base.\\n return string representation after the conversion.\\n base numbers are less than...</td>\n",
       "      <td id=\"T_1f859_row4_col4\" class=\"data row4 col4\" >code='\\n\\ndef change_base(x: int, base: int):\\n \"\"\"Change numerical base of input number x to base.\\n return string representation after the conversion.\\n base numbers are less than...</td>\n",
       "      <td id=\"T_1f859_row4_col5\" class=\"data row4 col5\" >1.0</td>\n",
       "      <td id=\"T_1f859_row4_col6\" class=\"data row4 col6\" >nan</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x2b2fa1410>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <div style='\n",
       "                    text-align: center; \n",
       "                    font-size: 16px; \n",
       "                    font-weight: bold; \n",
       "                    color: #555; \n",
       "                    margin: 10px 0;'>\n",
       "                    ... 119 more rows not displayed ...\n",
       "                </div>\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled HumanEval score: 60.48\n"
     ]
    }
   ],
   "source": [
    "print(\"Evaluating...\")\n",
    "print(\n",
    "    \"Compiled HumanEval score:\",\n",
    "    evaluator(compiled, metric=test_code(timeout=100)),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
